% This function receives the segmented parts and returns the result of the
% classification (1 is valid, 0 is not).

function is_valid = validateSegmentationResults(im_dg_in, im_dg_out_int, im_dg_out_ext, im_dg_out, im_ca1_in, im_ca1_out, im_ca3_in, im_ca3_out)
%VALIDATESEGMENTATIONRESULTS Classify a DG/CA1/CA3 segmentation as valid or not.
%   IS_VALID = VALIDATESEGMENTATIONRESULTS(IM_DG_IN, IM_DG_OUT_INT,
%   IM_DG_OUT_EXT, IM_DG_OUT, IM_CA1_IN, IM_CA1_OUT, IM_CA3_IN, IM_CA3_OUT)
%   runs a two-stage random-forest check:
%
%   1. For each part (in the order CA3, CA1, DG), the "inside" image is
%      binarized into a mask and added to the "outside" image. The sum is
%      converted to grayscale and intensity-normalized. NUMBER_OF_POINTS_PER_PART
%      random pixels are drawn inside the mask and as many outside it. Each
%      point is described by multi-scale texture features and classified by
%      the model stored in rf_model_onlytexture.mat. A prediction of 1 counts
%      as correct for an inside point, and 2 counts as correct for an outside
%      point.
%   2. The six fractions of correctly classified points, the area fractions
%      of the outer parts (complements of the CA3/CA1 masks, and the two DG
%      outer parts), and the median/std of the per-row mean of the CA3 mask
%      and the per-column mean of the CA1 mask are classified by the model
%      stored in rf_model_for_classifying_segmented_parts.mat.
%
%   IS_VALID is the prediction of that second model (per the file header,
%   1 is valid and 0 is not). It is false, without running the second
%   model, when any combined part image is smaller than 8x8 pixels or when
%   a part mask has no inside pixels or no outside pixels.
%
%   Requires classRF_predict and the .mat files loaded below to be on the
%   path (presumably arranged by the addpaths script).

    addpaths;

    NUMBER_OF_POINTS_PER_PART = 12;
    WIN_SIZES = [30, 90, 150, 210, 300];   % side lengths of the texture windows
    MIN_SIZE = 8;                          % smaller part images are rejected

    % Random forest trained to classify single points from their texture.
    S = load('rf_model_onlytexture', 'rf_model');
    texture_model = S.rf_model;

    % std_stand/mean_stand standardize the point features; um/ustd are the
    % target mean/std used to normalize the intensity of every part image.
    S = load('std_stand', 'std_stand');   std_stand  = S.std_stand;
    S = load('mean_stand', 'mean_stand'); mean_stand = S.mean_stand;
    S = load('um', 'um');                 um         = S.um;
    S = load('ustd', 'ustd');             ustd       = S.ustd;

    mask_ca1 = im2bw(im_ca1_in, 0.0001);
    mask_ca3 = im2bw(im_ca3_in, 0.0001);
    mask_dg  = im2bw(im_dg_in, 0.0001);

    % Processing order is CA3, CA1, DG. It fixes the layout of CORRECT_POINTS
    % (an [inside outside] pair per part), which is part of the feature vector
    % of the second model, so it must not be changed.
    part_masks = {mask_ca3, mask_ca1, mask_dg};
    part_in    = {im_ca3_in, im_ca1_in, im_dg_in};
    part_out   = {im_ca3_out, im_ca1_out, im_dg_out};

    n = NUMBER_OF_POINTS_PER_PART;
    expected = [1, 2];   % label counted as correct for inside / outside points
    CORRECT_POINTS = zeros(1, 6);
    for p = 1:numel(part_masks)
        mask_in = part_masks{p};
        im = part_in{p} + part_out{p};

        % Too small image, not valid!
        if size(im, 1) < MIN_SIZE || size(im, 2) < MIN_SIZE
            is_valid = false;
            return
        end

        % Points are drawn uniformly (with replacement) from each region. A
        % part with no inside or no outside pixels cannot be sampled and is
        % treated as an invalid segmentation.
        in_pixels  = find(mask_in);
        out_pixels = find(~mask_in);
        if isempty(in_pixels) || isempty(out_pixels)
            is_valid = false;
            return
        end

        im = normalizeIntensity(im, um, ustd);

        samples = {in_pixels(randi(numel(in_pixels), n, 1)), ...
                   out_pixels(randi(numel(out_pixels), n, 1))};
        slots = [2*p - 1, 2*p];
        for side = 1:2
            [xs, ys] = ind2sub(size(mask_in), samples{side});
            for k = 1:n
                pattern = extractTextureFeatures(im, xs(k), ys(k), WIN_SIZES);
                pattern(isnan(pattern)) = 0;
                pattern = (pattern - mean_stand) ./ std_stand;
                output = classRF_predict(pattern, texture_model);
                if output == expected(side)
                    CORRECT_POINTS(slots(side)) = CORRECT_POINTS(slots(side)) + 1;
                end
            end
        end
    end

    %% EXTRACT SUPPLEMENTARY FEATURES
    % Median/std of the per-column mean of the CA1 mask (height of CA1) and
    % of the per-row mean of the CA3 mask (width of CA3).
    med1 = median(mean(mask_ca1));
    std1 = std(mean(mask_ca1));
    med3 = median(mean(mask_ca3, 2));
    std3 = std(mean(mask_ca3, 2));

    % Fraction of the image covered by the outer parts. For CA1/CA3 this is
    % the complement of the inner mask; DG has two explicit outer parts.
    mask_dg_ext = im2bw(im_dg_out_ext, 0.0001);
    mask_dg_int = im2bw(im_dg_out_int, 0.0001);

    features = [CORRECT_POINTS / NUMBER_OF_POINTS_PER_PART, ...
                mean2(~mask_ca3), mean2(~mask_ca1), mean2(mask_dg_ext), mean2(mask_dg_int), ...
                med3, std3, med1, std1];

    S = load('rf_model_for_classifying_segmented_parts', 'rf_model');
    is_valid = classRF_predict(features, S.rf_model);
end

function im = normalizeIntensity(im, um, ustd)
% Convert IM to grayscale (if RGB) and rescale it linearly so its mean is UM
% and its standard deviation is USTD; the result is returned as uint8.
    if size(im, 3) == 3
        im = rgb2gray(im);
    end
    temp = double(im);
    im = uint8((temp - mean(temp(:))) * ustd / std(temp(:)) + um);
end

function pattern = extractTextureFeatures(im, x, y, win_sizes)
% Texture descriptor of pixel (X, Y) = (row, column) of grayscale image IM:
% 11 values per window size, concatenated in WIN_SIZES order. The formulas
% are kept exactly as written so the feature values do not change
% (presumably they must match what the texture model was trained on).
    FEATURES_PER_WINDOW = 11;
    pattern = zeros(1, FEATURES_PER_WINDOW * numel(win_sizes));
    for w = 1:numel(win_sizes)
        half = win_sizes(w) / 2;
        % Window centred on the point, clipped to the image borders.
        r1 = max(1, round(x - half));
        r2 = min(size(im, 1), round(x + half));
        c1 = max(1, round(y - half));
        c2 = min(size(im, 2), round(y + half));
        win = im(r1:r2, c1:c2);

        h = imhist(win, 16);
        h = h / sum(h);
        cv = std2(win) / mean2(win);
        % h is already normalized; the expression is kept verbatim so the
        % computed value is unchanged.
        ener = sum(h/sum(h).*h/sum(h));
        G = graycomatrix(win, 'symmetric', true, 'Offset', [1,1], 'NumLevels', 16);
        props = graycoprops(G);
        pattern((w - 1) * FEATURES_PER_WINDOW + (1:FEATURES_PER_WINDOW)) = ...
            [mean2(win), std2(win), cv, skewness(h), kurtosis(h), ener, entropy(h), ...
             props.Contrast, props.Correlation, props.Energy, props.Homogeneity];
    end
end